from numpy import *
import copy

class UniformCrosser(object):
   """Uniform crossover: every gene comes from one of the first two parents.

   Each position independently takes ``orgs[0]``'s gene or ``orgs[1]``'s
   gene with (essentially) equal probability.
   """

   def __call__(self, orgs):
      """Return a child mixing ``orgs[0]`` and ``orgs[1]`` gene by gene."""
      first, second = orgs[0], orgs[1]
      # Rounding uniform [0, 1) samples gives a 0/1 selection mask:
      # 1 keeps the gene from ``first``, 0 takes it from ``second``.
      mask = random.random_sample(len(first)).round()
      return mask * first + (1 - mask) * second

class OnePointCrosser(object):
   """One-point crossover: head of the first parent, tail of the second.

   A cut index ``k`` is drawn uniformly from ``[0, len - 1]`` and the child is
   ``orgs[0][:k]`` followed by ``orgs[1][k:]``. ``k == 0`` (child taken wholly
   from ``orgs[1]``) is possible; ``k == len`` is not.
   """

   def __call__(self, orgs):
      """Return the child of parents ``orgs[0]`` and ``orgs[1]`` as an array."""
      n = len(orgs[0])
      # randint excludes its upper bound; randint(0, n) is the documented
      # replacement for the deprecated random_integers(0, n - 1). An empty
      # genome has no valid cut, so use 0 (empty child) instead of raising.
      idx = random.randint(0, n) if n else 0
      return concatenate((orgs[0][:idx], orgs[1][idx:]), axis=0)
      
class TwoPointCrosser(object):
   """Two-point crossover: splice a segment of the second parent into the first.

   Two distinct cut points ``lo < hi`` are drawn uniformly from
   ``[0, len - 1]``; the child is ``orgs[0][:lo]``, then ``orgs[1][lo:hi]``,
   then ``orgs[0][hi:]``.
   """

   def __call__(self, orgs):
      """Return the child of parents ``orgs[0]`` and ``orgs[1]`` as an array.

      Genomes with fewer than two genes cannot supply two distinct cut
      points; for them a copy of ``orgs[0]`` is returned (previously a
      one-gene genome looped forever and an empty one raised).
      """
      n = len(orgs[0])
      if n < 2:
         return array(orgs[0], copy=True)
      # randint's upper bound is exclusive, so randint(0, n) is the
      # documented replacement for the deprecated random_integers(0, n - 1).
      lo = random.randint(0, n)
      hi = lo
      while hi == lo:  # rejection-sample a second, distinct cut point
         hi = random.randint(0, n)
      if lo > hi:
         lo, hi = hi, lo
      return concatenate((orgs[0][:lo], orgs[1][lo:hi], orgs[0][hi:]), axis=0)


class IntermediateCrosser(object):
   """Intermediate recombination: the child is the gene-wise mean of all parents."""

   def __call__(self, orgs):
      """Return the element-wise average of the parent organisms in ``orgs``.

      ``orgs`` may hold any number of equally shaped parents.
      """
      # ``from numpy import *`` shadows the builtin ``sum`` with numpy.sum,
      # which without an axis collapses every parent into a single scalar.
      # Averaging explicitly along axis 0 keeps one value per gene.
      return asarray(orgs).mean(axis=0)

class Crossover(object):
   """Recombine the current population with organisms picked by a selector.

   ``update`` stores a private deep copy of the population; ``batch`` then
   groups each stored organism with ``order - 1`` organisms drawn from
   ``selector.batch`` and passes each group to ``crossFunc``.

   :param selector: object with a ``batch(popSize)`` method returning a list
                    of organisms.
   :param crossFunc: callable taking a tuple of ``order`` organisms and
                     returning one child (e.g. ``UniformCrosser()``).
   :param order: number of parents combined into each child.
   """

   def __init__(self, selector, crossFunc, order=2):
      self.selector = selector
      self.crosser = crossFunc
      self.order = order
      # Filled in by update(); initialised here so the attribute always exists.
      self.population1 = None

   def batch(self, popSize):
      """Return a list of children, one per group of parents.

      The first parent of each group comes, in order, from the population
      passed to the last ``update``; the rest come from ``order - 1`` calls to
      ``selector.batch(popSize)``. ``zip`` truncates to the shortest list.
      """
      pops = [[x for x, s in self.population1]]
      # range (not xrange) so this runs on both Python 2 and Python 3.
      for _ in range(self.order - 1):
         pops.append(self.selector.batch(popSize))
      return [self.crosser(orgs) for orgs in zip(*pops)]

   def update(self, generation, population):
      """Store a deep copy of ``population``.

      ``population`` is a sequence of pairs whose first element is the
      organism (the second, presumably a fitness score, is ignored here).
      The selector should already be updated; this call just hands us the
      selected result.
      """
      self.population1 = copy.deepcopy(population)
      

class Gaussian(object):
   """Gaussian mutation: add zero-mean normal noise scaled by ``stddev``.

   :param stddev: noise standard deviation; a scalar, or anything that
                  broadcasts against an organism (e.g. one value per gene).
   """

   def __init__(self, stddev):
      self.sd = stddev
      self.population = None

   def batch(self, popSize):
      """Return one mutated copy of each organism from the last ``update``.

      ``popSize`` is ignored; the output has one entry per stored organism.
      Noise is drawn with the organism's full shape, so multi-dimensional
      genomes get independent noise per element. For 1-D genomes the draws
      match ``randn(len(x))`` exactly.
      """
      return [x + random.standard_normal(shape(x)) * self.sd
              for x, s in self.population]

   def update(self, generation, population):
      """Remember ``population`` (sequence of (organism, _) pairs) by reference."""
      self.population = population
      
class Bernoulli(object):
   """Bit-flip mutation, intended for organisms made of 0/1 genes.

   :param bitFlipProbs: flip probability; a scalar, or anything that
                        broadcasts against an organism (e.g. one per gene).
   """

   def __init__(self, bitFlipProbs):
      self.bitFlipProbs = bitFlipProbs
      self.population = None

   def batch(self, popSize):
      """Return one mutated copy of each organism from the last ``update``.

      ``popSize`` is ignored; the output has one entry per stored organism.
      Flip draws use the organism's full shape, so 2-D genomes work too. For
      1-D genomes the draws match ``size=len(x)`` exactly.
      """
      # abs(x - flip) inverts a gene exactly where flip == 1, assuming the
      # genes are 0 or 1.
      return [abs(x - random.binomial(n=1, p=self.bitFlipProbs, size=shape(x)))
              for x, s in self.population]

   def update(self, generation, population):
      """Remember ``population`` (sequence of (organism, _) pairs) by reference."""
      self.population = population
         